#!/usr/bin/env python3
import dataclasses
from typing import Any

import numpy as np
import rclpy
import tf2_geometry_msgs
import tf_transformations
import yaml
from geometry_msgs.msg import (
    Pose,
    PoseStamped,
    PoseWithCovarianceStamped,
    Transform,
    TransformStamped,
)
from pyquaternion import Quaternion
from rclpy.duration import Duration
from rclpy.node import Node
from rclpy.parameter import Parameter
from rclpy.qos import QoSHistoryPolicy, QoSProfile, QoSReliabilityPolicy
from tf2_msgs.msg import TFMessage
from tf2_ros import TransformException
from tf2_ros.buffer import Buffer
from tf2_ros.transform_listener import TransformListener

from visual_localization.ekf import (
    EkfParams,
    ExtendedKalmanFilter,
)


def get_tag_id(frame_id: str) -> int:
    """Return the numeric tag id encoded in a tag frame name.

    A leading ``'tag36h11:'`` is dropped if present; the remaining text is
    parsed as an integer.

    Raises:
        ValueError: If the remaining text is not an integer literal.
    """
    family_prefix = 'tag36h11:'
    if frame_id.startswith(family_prefix):
        id_text = frame_id[len(family_prefix):]
    else:
        id_text = frame_id
    return int(id_text)


def transform_to_pose(tf: Transform):
    """Build a ``Pose`` carrying the same values as the given transform.

    The translation is copied into the pose position and the rotation
    quaternion into the pose orientation, component by component.

    Args:
        tf: Transform whose translation and rotation are copied.

    Returns:
        A newly created ``Pose``; ``tf`` itself is not modified.
    """
    result = Pose()
    for axis in ('x', 'y', 'z'):
        setattr(result.position, axis, getattr(tf.translation, axis))
    for component in ('w', 'x', 'y', 'z'):
        setattr(result.orientation, component, getattr(tf.rotation, component))
    return result


class EkfNode(Node):
    def __init__(self, node_name):
        """Create the filter and wire up every ROS interface of the node.

        Declares the ``camera_name`` and ``vehicle_name`` parameters, builds
        the EKF from its parameter set, loads the known tag poses, and
        creates the publisher, the two subscriptions and the two timers.

        Args:
            node_name: Name handed to the ``Node`` constructor.
        """
        super().__init__(node_name=node_name)

        # Both names are declared as string parameters without a default.
        for required_name in ('camera_name', 'vehicle_name'):
            self.declare_parameter(required_name, Parameter.Type.STRING)
        self.camera_name = str(self.get_parameter('camera_name').value)
        self.vehicle_name = self.get_parameter('vehicle_name').value

        self.ekf_params = self.init_ekf_params()
        self.ekf = ExtendedKalmanFilter(self.ekf_params)
        # Entries per detected tag in the measurement vector:
        # distance and yaw.
        self.dim_meas = 2

        self.tf_buffer = Buffer()
        self.tf_listener = TransformListener(self.tf_buffer, self)

        # Static transforms are looked up lazily by the first tag callback.
        self.initialized = False
        self.camera_frame_to_base_link_transform = None

        self.t_last_prediction = self.get_clock().now()
        self.tags = self.load_tag_poses()

        vision_pose_topic = f'/{self.vehicle_name}/vision_pose_cov'
        self.vision_pose_pub = self.create_publisher(
            PoseWithCovarianceStamped, vision_pose_topic, qos_profile=1
        )

        # publish all ekf states? -> TODO
        best_effort_qos = QoSProfile(
            history=QoSHistoryPolicy.KEEP_LAST,
            depth=1,
            reliability=QoSReliabilityPolicy.BEST_EFFORT,
        )

        self.tag_transforms_sub = self.create_subscription(
            TFMessage,
            'tag_transforms',
            self.on_tag_transforms,
            best_effort_qos,
        )

        # TODO: Update this subscriber!
        self.orientation_sub = self.create_subscription(
            PoseStamped,
            'mavros/local_position/pose',
            self.on_orientation,
            best_effort_qos,
        )

        # Prediction and pose publishing share the same timer period.
        timer_period = 1 / 30
        self.predict_timer = self.create_timer(
            timer_period_sec=timer_period, callback=self.predict
        )
        self.send_vision_pose_timer = self.create_timer(
            timer_period_sec=timer_period, callback=self.send_vision_pose
        )

    def init_ekf_params(self):
        """Fill a fresh ``EkfParams`` instance from ROS parameters.

        The dataclass is walked recursively. Every non-dataclass leaf is
        declared as a ROS parameter named after its dotted path below
        ``ekf_params`` (e.g. ``ekf_params.<field>.<subfield>``). The parameter
        type is derived from the Python type of the leaf's current value, and
        the leaf is then overwritten with the parameter's value.

        Returns:
            The populated ``EkfParams`` instance.

        Raises:
            TypeError: If a leaf value is not a bool, int, float or str.
        """

        def recursion(data: Any, name: str):
            if not dataclasses.is_dataclass(data):
                # bool has to be tested before int: bool is a subclass of int
                # and would otherwise be declared as an INTEGER parameter.
                if isinstance(data, bool):
                    param_type = Parameter.Type.BOOL
                elif isinstance(data, int):
                    param_type = Parameter.Type.INTEGER
                elif isinstance(data, float):
                    param_type = Parameter.Type.DOUBLE
                elif isinstance(data, str):
                    param_type = Parameter.Type.STRING
                else:
                    raise TypeError(f'Unsupported type "{type(data)}"')
                self.declare_parameter(name, value=param_type)
                return self.get_parameter(name).value
            # dataclasses.fields() also reports inherited fields and leaves
            # out ClassVar pseudo-fields, unlike the raw __annotations__.
            for field in dataclasses.fields(data):
                setattr(
                    data,
                    field.name,
                    recursion(
                        getattr(data, field.name), f'{name}.{field.name}'
                    ),
                )
            return data

        return recursion(EkfParams(), 'ekf_params')

    def init_tf_transforms(self):
        """Look up the transforms from the camera frame and store them.

        Frame names are built from the ``vehicle_name`` parameter (with
        leading/trailing slashes stripped) and ``self.camera_name``. The
        transforms from ``<camera>_frame`` to ``base_link`` and to
        ``<camera>_link`` are requested with a timeout of two seconds each.

        Returns:
            True if both lookups succeeded, False if a ``TransformException``
            was raised (a warning is logged in that case).
        """
        vehicle_prefix = self.get_parameter('vehicle_name').value.strip('/')
        camera_prefix = vehicle_prefix + '/' + self.camera_name
        camera_frame = camera_prefix + '_frame'
        camera_link = camera_prefix + '_link'
        base_link = vehicle_prefix + '/base_link'
        timeout = Duration(seconds=2.0)

        def lookup_from_camera_frame(target):
            # Both lookups differ only in the target frame.
            return self.tf_buffer.lookup_transform(
                target_frame=target,
                source_frame=camera_frame,
                time=rclpy.time.Time(),
                timeout=timeout,
            )

        try:
            self.camera_frame_to_base_link_transform = (
                lookup_from_camera_frame(base_link)
            )
            self.camera_frame_to_camera_link_transform = (
                lookup_from_camera_frame(camera_link)
            )
        except TransformException as e:
            self.get_logger().warning(
                f'After waiting {timeout}, could not look up transform from'
                ' camera frame to base link for vehicle'
            )
            self.get_logger().warn(f'{e}')
            return False
        return True

    def on_tag_transforms(self, msg: TFMessage):
        """Turn the detected tags of one message into an EKF vision update.

        On the first call(s) the static camera transforms are looked up; the
        message is dropped until that succeeded. For every transform whose
        tag id is found exactly once in ``self.tags`` two measurements are
        computed: the distance between tag and base link, and a yaw angle
        derived from the measured and the known tag orientation.

        Transforms that cannot be used (unparsable frame name, unknown or
        ambiguous tag id) are logged and left out of the update entirely, so
        that no placeholder rows reach the filter. If no usable tag remains,
        the filter is not updated.

        Args:
            msg: Tag transforms; each ``child_frame_id`` carries the tag id.
        """
        if not self.initialized:
            self.initialized = self.init_tf_transforms()
            if not self.initialized:
                return
            self.get_logger().info('Initialized TF transforms for EKF.')

        if not msg.transforms:
            return  # no tags detected, nothing to update

        # One entry per usable tag; stacked after the loop.
        measurement_blocks = []
        detected_tag_rows = []

        tf_stamped: TransformStamped
        for tf_stamped in msg.transforms:
            try:
                tag_id = get_tag_id(tf_stamped.child_frame_id)
            except ValueError:
                self.get_logger().error(
                    'Could not parse a tag id from frame '
                    f'"{tf_stamped.child_frame_id}". Skipping this transform.'
                )
                continue

            # add this tag's information to list of detected tags
            index = np.where(self.tags[:, 0] == tag_id)
            if len(index[0]) != 1:
                self.get_logger().error(
                    'Expected to find exactly one occurance of '
                    f'tag with id={tag_id}. But found occurances at {index}'
                )
                continue
            tag_row = self.tags[int(index[0][0])]
            detected_tag_rows.append(tag_row[0:4])

            # The tag transform is used directly as the tag pose in the
            # camera frame and then expressed in base_link.
            tag_pose_in_camera_frame = transform_to_pose(tf_stamped.transform)
            tag_pose_in_base_link = tf2_geometry_msgs.do_transform_pose(
                tag_pose_in_camera_frame,
                self.camera_frame_to_base_link_transform,
            )

            position_tag_in_base_link = np.array(
                [
                    tag_pose_in_base_link.position.x,
                    tag_pose_in_base_link.position.y,
                    tag_pose_in_base_link.position.z,
                ]
            ).reshape(-1, 1)
            q_meas = [
                tag_pose_in_base_link.orientation.x,
                tag_pose_in_base_link.orientation.y,
                tag_pose_in_base_link.orientation.z,
                tag_pose_in_base_link.orientation.w,
            ]

            # inverse quaternion to get base_link orientation
            # in tag frame, which has same orientation as map frame
            q_meas = tf_transformations.quaternion_inverse(q_meas)
            q_meas = Quaternion(
                x=q_meas[0], y=q_meas[1], z=q_meas[2], w=q_meas[3]
            )
            # Columns 4..7 of a tag row hold qx, qy, qz, qw.
            qx, qy, qz, qw = tag_row[4:].tolist()
            q_truth = Quaternion(x=qx, y=qy, z=qz, w=qw)
            q = q_truth * q_meas
            # TODO(lennartalff): THIS IS WRONG! We want the yaw angle relative
            # to the world frame and not relative to the tag!
            # pyquaternion indexes as (w, x, y, z); reorder to (x, y, z, w).
            _, _, yaw = tf_transformations.euler_from_quaternion(
                [q[1], q[2], q[3], q[0]]
            )

            tag_measurement = np.zeros(self.dim_meas)
            # measurement 1: distance tag - base_link
            tag_measurement[0] = np.linalg.norm(position_tag_in_base_link)
            # measurement 2: yaw angle in map frame
            tag_measurement[1] = yaw
            measurement_blocks.append(tag_measurement)

        if not detected_tag_rows:
            return  # every detection was rejected, nothing to update

        # Column vector [meas(tag 0), meas(tag 1), ...] and one row
        # (id, x, y, z) per tag, in matching order.
        measurements = np.concatenate(measurement_blocks).reshape(-1, 1)
        detected_tags = np.vstack(detected_tag_rows)
        self.ekf.update_vision_data(measurements, detected_tags)

    def on_orientation(self, msg: PoseStamped):
        """Pass roll and pitch of the received pose to the EKF.

        The orientation quaternion is converted to Euler angles; the yaw
        component is ignored here.

        Args:
            msg: Pose whose orientation is evaluated.
        """
        quat = msg.pose.orientation
        roll, pitch, _ = tf_transformations.euler_from_quaternion(
            [quat.x, quat.y, quat.z, quat.w]
        )
        # The filter receives the two angles as a (2, 1) column vector.
        roll_pitch = np.array([[roll], [pitch]])
        self.ekf.update_orientation_data(roll_pitch)

    def predict(self):
        """Run one EKF prediction step over the time since the last one.

        The elapsed time is measured with the node clock, handed to the
        filter in seconds, and the current time is stored for the next call.
        """
        stamp = self.get_clock().now()
        elapsed = stamp - self.t_last_prediction
        self.ekf.predict(elapsed.nanoseconds / 1e9)
        self.t_last_prediction = stamp

    def send_vision_pose(self):
        """Publish the current EKF estimate as ``PoseWithCovarianceStamped``.

        The first six state entries are read as x, y, z, roll, pitch, yaw.
        The Euler angles are converted to a quaternion, and the upper-left
        6x6 block of the filter covariance is published row by row. The
        message is stamped with the current node time in the ``map`` frame.
        """
        x_est = self.ekf.get_x_est()
        cov_est = self.ekf.get_p_mat()[:6, :6]

        # The state is indexed as a 2-D column vector, so the column has to
        # be selected explicitly. Converting to plain floats avoids handing
        # one-element arrays to the Euler conversion and the message fields.
        x, y, z, roll, pitch, yaw = (float(value) for value in x_est[:6, 0])
        quat = tf_transformations.quaternion_from_euler(roll, pitch, yaw)

        msg = PoseWithCovarianceStamped()
        msg.header.stamp = self.get_clock().now().to_msg()
        msg.header.frame_id = 'map'
        msg.pose.pose.position.x = x
        msg.pose.pose.position.y = y
        msg.pose.pose.position.z = z
        # quaternion_from_euler returns the components ordered x, y, z, w
        msg.pose.pose.orientation.x = float(quat[0])
        msg.pose.pose.orientation.y = float(quat[1])
        msg.pose.pose.orientation.z = float(quat[2])
        msg.pose.pose.orientation.w = float(quat[3])
        msg.pose.covariance = cov_est.flatten().tolist()
        self.vision_pose_pub.publish(msg)

    def load_tag_poses(self):
        """Load the known tag poses from the file named by ``tag_poses_file``.

        The YAML file must contain a ``tag_poses`` list whose entries provide
        ``frame_id``, ``id``, ``x``, ``y``, ``z``, ``qx``, ``qy``, ``qz`` and
        ``qw``. Only tags given in the ``map`` frame are supported; others
        are reported and left out. If an id occurs more than once, the last
        entry wins.

        Returns:
            Array of shape (number of usable tags, 8) with one row
            ``[id, x, y, z, qx, qy, qz, qw]`` per tag. Rows are found by
            matching the id column, so tag ids do not have to be contiguous
            or start at zero.

        Raises:
            SystemExit: If the ``tag_poses_file`` parameter holds no string.
        """
        self.declare_parameter('tag_poses_file', value=Parameter.Type.STRING)
        filepath = self.get_parameter('tag_poses_file').value
        if not isinstance(filepath, str):
            self.get_logger().fatal('no value for tag_poses_file param')
            raise SystemExit(1)
        with open(filepath, 'r') as f:
            data = yaml.safe_load(f)

        columns = ('id', 'x', 'y', 'z', 'qx', 'qy', 'qz', 'qw')
        # Keyed by tag id so a repeated id replaces the earlier entry instead
        # of producing two rows with the same id.
        rows_by_id = {}
        for tag in data['tag_poses']:
            if tag['frame_id'] != 'map':
                self.get_logger().warning(
                    'Tag not in map frame! - not implemented yet'
                )
                continue
            rows_by_id[tag['id']] = [tag[column] for column in columns]

        # reshape keeps the result two-dimensional even without any tag
        return np.array(list(rows_by_id.values()), dtype=float).reshape(
            -1, len(columns)
        )


def main():
    """Start rclpy, spin the EKF node until shutdown, then clean up."""
    rclpy.init()
    node = EkfNode('ekf_node')
    try:
        rclpy.spin(node)
    except KeyboardInterrupt:
        # Ctrl-C is the regular way to stop the node; no traceback needed.
        pass
    finally:
        node.destroy_node()
        # The context may already have been shut down by a signal handler.
        if rclpy.ok():
            rclpy.shutdown()


# Run the node when this file is executed directly as a script.
if __name__ == '__main__':
    main()
